# -*- coding: utf-8 -*-
"""
MPA 双网络差异法（公共子图 + Critical-only + 二值度SC门槛 + 零覆盖纳入）+ 灵敏度网格
====================================================================================

基准口径（统一用于描述/制图与 ALAAM 基准）：
  - Critical-only（分海区分期的高分位阈值）
  - P=0.85（Q85），SC门槛=Q40（基于空间二值度数）
  - 三期（2015/2020/2025）计算口径一致；ALAAM 面板仍只输出 2020/2025（考虑滞后）

同时自动输出“稳健性网格”（P ∈ {0.85,0.90,0.95} × SC_Q ∈ {0.40,0.50,0.60}），
含对 2015 的 SC 门槛统计（标签文件与计数），以及 2020/2025 的面板文件。

输出：
  - outputs_alaam/matrices/   ：空间/制度矩阵（权重/二值）
  - outputs_alaam/edgelists/  ：ArcGIS 边表（空间=全边；制度=强边）
  - outputs_alaam/nodes/      ：misfit_dual_diff_YYYY.csv（含 M_raw, M, SC_deg, covered_any 等）
  - outputs_alaam/summary/    ：topology_summary.csv；sensitivity_grid_summary.csv（含2015/2020/2025计数）
  - outputs_alaam/alaam/      ：
        - alaam_nodes_panel.csv （基准：P=0.85 & SC=Q40，仅含 2020/2025）
        - alaam_nodes_panel__critical_p{85,90,95}_scq{40,50,60}.csv（稳健性网格，2020/2025）
        - y2015_labels__critical_p{85,90,95}_scq{40,50,60}.csv（2015标签，含 SC 门槛）

依赖：pandas, numpy, networkx, openpyxl
"""

import os
import warnings
from pathlib import Path
from math import radians, sin, cos, asin, sqrt
import numpy as np
import pandas as pd
import networkx as nx

# =========================
# CONFIG
# =========================
# Base directory for inputs; all outputs go under DATA_DIR / "outputs_alaam".
DATA_DIR = Path(".")
OUT_DIR = DATA_DIR / "outputs_alaam"
# NOTE: the output sub-directories are created as a side effect of importing this module.
for sub in ["matrices", "edgelists", "nodes", "summary", "alaam"]:
    (OUT_DIR / sub).mkdir(parents=True, exist_ok=True)

# NOTE(review): this is a raw string, so every "\\" stays as two backslash
# characters in the path — confirm that is intended.
ABS_HINT = Path(r"D:\\Fragmentation of MPA governance\\mpa_network_misfit\\pythonProject1\\.venv")
# Directories probed (in this order) when an MPA Excel file is not found at the requested path.
DATA_DIR_CANDIDATES = [
    ABS_HINT,
    Path(os.getenv("MPA_DATA_DIR", "") or ""),
    Path.cwd(),
    Path(__file__).resolve().parent,
    Path(__file__).resolve().parent / "data",
]
# NOTE(review): Path("") stringifies to ".", so this filter probably never
# removes the unset-environment-variable entry — confirm.
DATA_DIR_CANDIDATES = [d for d in DATA_DIR_CANDIDATES if str(d).strip() != ""]

YEARS = [2015, 2020, 2025]
PANEL_TARGET_YEARS = [2020, 2025]  # ALAAM uses only these two waves (lagged predictors)

# -- Spatial network -- (top-K nearest neighbours + row normalisation + symmetrisation)
K_SPATIAL_NEIGHBORS = 8

# -- Institutional network (policy coverage mapping + UICS composite score) --
TARGET_BINARY_DENSITY = 0.08
LEVEL_BASE = 0.60
K0 = 3.0
SOFT_TAU_Q = 0.75

# -- Misfit baseline specification (unified: P=0.85 & SC=Q40) --
MISFIT_MODE = "critical_only"          # fixed to Critical-only
HIGH_PCTL_MAIN = 0.85                  # baseline Critical threshold (per sea region, per wave)
SC_GATE_Q_MAIN = 0.40                  # baseline SC gate (quantile of binary degree within a sea region)
ABS_BASELINE = 0.10                    # lower bound for the M threshold, guards against extreme distributions

# -- Sensitivity grid -- (includes P=0.90 & SC=Q50)
SENS_PCTLS = [0.85, 0.90, 0.95]
SENS_SCQS  = [0.40, 0.50, 0.60]

# -- Miscellaneous --
EPS = 1e-9
# Raw per-level weights; rescaled to sum to 1 right after _normalize is defined.
INSTITUTION_LEVEL_WEIGHTS = {"National": 1.0, "Provincial": 0.7, "Municipal": 0.4}
def _normalize(d):
    """Return a new dict whose values are those of ``d`` divided by their sum.

    When the sum is not strictly positive the values are returned unscaled.
    """
    total = float(sum(d.values()))
    if total > 0:
        return {key: value / total for key, value in d.items()}
    return {key: value for key, value in d.items()}
INSTITUTION_LEVEL_WEIGHTS = _normalize(INSTITUTION_LEVEL_WEIGHTS)

# ========================= 工具 =========================

def _resolve_xlsx_with_fallback(xlsx_path: Path) -> Path:
    """Locate an Excel file, falling back to the known data directories.

    The path is returned unchanged when it exists. Otherwise the bare file
    name is probed in every entry of ``DATA_DIR_CANDIDATES`` (in order) and
    finally in the parent of ``OUT_DIR``; the first existing match wins.

    Raises:
        FileNotFoundError: if no probed location contains the file.
    """
    requested = Path(xlsx_path)
    if requested.exists():
        return requested

    filename = requested.name
    for folder in DATA_DIR_CANDIDATES:
        # Best-effort probing: a candidate that cannot be checked is skipped.
        try:
            probe = folder / filename
            found = probe.exists()
        except Exception:
            continue
        if found:
            return probe

    try:
        sibling = OUT_DIR.parent / filename
        if sibling.exists():
            return sibling
    except Exception:
        pass

    tried = "; ".join(str(folder / filename) for folder in DATA_DIR_CANDIDATES)
    raise FileNotFoundError(f"MPA Excel not found: '{filename}'. Tried: " + tried)

def first_existing_col(df: pd.DataFrame, candidates, required=True):
    """Return the first name in ``candidates`` that is a column of ``df``.

    Returns ``None`` when nothing matches and ``required`` is false.

    Raises:
        KeyError: when nothing matches and ``required`` is true.
    """
    not_found = object()
    match = next((name for name in candidates if name in df.columns), not_found)
    if match is not not_found:
        return match
    if not required:
        return None
    raise KeyError(f"未找到字段：{candidates}")

def gini(arr):
    """Return the sample-corrected Gini coefficient of ``arr``.

    Uses the rank formula on the ascending-sorted values with the
    ``n / (n - 1)`` small-sample correction, so a perfectly equal
    distribution gives 0 and "one element holds everything" gives 1.

    Returns ``nan`` for empty input and ``0.0`` when the mean is zero or
    there is a single observation (no inequality can be measured).
    """
    x = np.sort(np.array(arr, dtype=float))
    n = len(x)
    if n == 0:
        return np.nan
    mu = x.mean()
    # n == 1 would divide by (n - 1) == 0 below.
    if mu == 0 or n == 1:
        return 0.0
    ranks = np.arange(1, n + 1)
    weighted_sum = np.sum((n + 1 - ranks) * x)
    # G = (n+1)/(n-1) - 2*S / ((n-1)*n*mu). The leading term must be
    # (n+1)/(n-1), not 1, otherwise every result is shifted by -2/(n-1).
    return float((n + 1) / (n - 1) - 2.0 * weighted_sum / ((n - 1) * n * mu))

# ========================= 读入与预处理 =========================

def load_mpa_attrs(xlsx_path: Path):
    """Read the MPA attribute sheet and normalise its key columns.

    The Excel file is located via ``_resolve_xlsx_with_fallback`` and read
    with ``pd.read_excel``. The returned frame is guaranteed to contain:

    * ``MPA_ID``  - renamed from a matching column, else from a column named
      ``index``, else a generated 1..n range;
    * ``lon`` / ``lat`` - numeric decimal degrees (unparseable cells become NaN);
    * ``province`` / ``sea_region`` / ``city`` - filled with ``'Unknown'`` when
      no matching column exists.

    Raises:
        KeyError: if no longitude or latitude column can be identified.
    """
    xlsx_path = _resolve_xlsx_with_fallback(xlsx_path)
    df = pd.read_excel(xlsx_path)
    import re as _re
    def _norm(s):
        # Lower-case and drop everything except ASCII letters/digits and CJK ideographs.
        return _re.sub(r'[^0-9A-Za-z\u4e00-\u9fff]+', '', str(s).strip().lower())
    # normalised header -> original header (later duplicates overwrite earlier ones)
    cols_norm_map = {_norm(c): c for c in df.columns}
    def _find_col(cands):
        # For each candidate in order: exact normalised match first, then the
        # first header that merely starts with the candidate.
        # NOTE(review): with short candidates such as 'x', 'y' or 'id' the
        # prefix rule can pick an unrelated column (e.g. one starting with 'y').
        for key in cands:
            k = _norm(key)
            if k in cols_norm_map:
                return cols_norm_map[k]
            for nk, orig in cols_norm_map.items():
                if nk.startswith(k):
                    return orig
        return None
    # ID
    id_cands  = ['mpa_id','mpaid','id','编号','保护地id','保护地编号']
    id_col = _find_col(id_cands)
    if id_col is not None and id_col != 'MPA_ID':
        df.rename(columns={id_col: 'MPA_ID'}, inplace=True)
    if 'MPA_ID' not in df.columns:
        for c in df.columns:
            if str(c).strip().lower() == 'index':
                df.rename(columns={c:'MPA_ID'}, inplace=True)
                break
    if 'MPA_ID' not in df.columns:
        # Last resort: synthesise sequential IDs starting at 1.
        df['MPA_ID'] = range(1, len(df)+1)
    # lon/lat (degree-minute-second strings are supported)
    lon_cands = ['lon','lng','long','longitude','lon_dd','longdd','x','xcoord','xcoordinate','经度','中心经度']
    lat_cands = ['lat','latitude','lat_dd','latdd','y','ycoord','ycoordinate','纬度','中心纬度']
    def _dms_to_dd(s):
        # Convert one cell to decimal degrees; returns None when it cannot be parsed.
        if pd.isna(s): return None
        try: return float(s)
        except Exception: pass
        s = str(s)
        m = _re.match(r'^\s*([+-]?\d+(?:\.\d+)?)\s*$', s)
        if m: return float(m.group(1))
        # degrees / minutes / seconds separated by arbitrary non-digit characters
        m = _re.match(r'^\s*([+-]?\d+)\D+(\d+)?\D+(\d+(?:\.\d+)?)?\D*\s*$', s)
        if m:
            deg = float(m.group(1)); minutes = float(m.group(2) or 0.0); seconds = float(m.group(3) or 0.0)
            # NOTE(review): the sign comes from `deg < 0`, so a "-0 ..." value
            # loses its sign; trailing letters (e.g. W/S) are matched by \D* but
            # not interpreted — confirm inputs never rely on either.
            sign = -1.0 if deg < 0 else 1.0; deg = abs(deg)
            return sign * (deg + minutes/60.0 + seconds/3600.0)
        # three plain numbers separated by commas and/or spaces
        parts = _re.split(r'[, ]+', s.strip())
        if len(parts) == 3 and all(_re.match(r'^[+-]?\d+(?:\.\d+)?$', p) for p in parts):
            deg, minutes, seconds = map(float, parts)
            sign = -1.0 if deg < 0 else 1.0; deg = abs(deg)
            return sign * (deg + minutes/60.0 + seconds/3600.0)
        return None
    lon_col = _find_col(lon_cands)
    lat_col = _find_col(lat_cands)
    if lon_col is not None: df['lon'] = df[lon_col].apply(_dms_to_dd)
    if lat_col is not None: df['lat'] = df[lat_col].apply(_dms_to_dd)
    # Second chance: any header whose normalised form starts with 'longitude' / 'latitude'.
    if 'lon' not in df.columns:
        for c in df.columns:
            if _norm(c).startswith('longitude'):
                df['lon'] = df[c].apply(_dms_to_dd); break
    if 'lat' not in df.columns:
        for c in df.columns:
            if _norm(c).startswith('latitude'):
                df['lat'] = df[c].apply(_dms_to_dd); break
    if 'lon' not in df.columns or 'lat' not in df.columns:
        have = ', '.join(df.columns.astype(str).tolist())
        raise KeyError("Missing 'lon'/'lat' columns after normalization. Available columns: " + have)
    # Anything still non-numeric (including None from _dms_to_dd) becomes NaN.
    df['lon'] = pd.to_numeric(df['lon'], errors='coerce')
    df['lat'] = pd.to_numeric(df['lat'], errors='coerce')
    # province / sea_region, with an 'Unknown' fallback
    prov_cands = ['province','prov','province_name','省','省份','所在省','行政省份','所属省份','省级行政区']
    sea_cands  = ['sea_region','searegion','sea','marine_region','海区','海域','所属海区','海区名称']
    prov_col = _find_col(prov_cands)
    sea_col  = _find_col(sea_cands)
    if prov_col is not None and prov_col != 'province':
        df.rename(columns={prov_col:'province'}, inplace=True)
    if sea_col is not None and sea_col != 'sea_region':
        df.rename(columns={sea_col:'sea_region'}, inplace=True)
    if 'province' not in df.columns: df['province'] = 'Unknown'
    if 'sea_region' not in df.columns: df['sea_region'] = 'Unknown'
    # city is optional
    if 'city' not in df.columns: df['city'] = 'Unknown'
    return df

# ========================= 空间网络 =========================

def haversine_km(lon1, lat1, lon2, lat2):
    """Great-circle distance in kilometres between two points.

    Coordinates are longitude/latitude in decimal degrees; the haversine
    formula is evaluated on a sphere of radius 6371.0088 km.
    """
    earth_radius_km = 6371.0088
    lam1 = radians(lon1)
    phi1 = radians(lat1)
    lam2 = radians(lon2)
    phi2 = radians(lat2)
    delta_lam = lam2 - lam1
    delta_phi = phi2 - phi1
    hav = sin(delta_phi/2)**2 + cos(phi1)*cos(phi2)*sin(delta_lam/2)**2
    central_angle = 2 * asin(sqrt(hav))
    return earth_radius_km * central_angle

def build_spatial_W_topk(coords_df: pd.DataFrame, k=8):
    """Build the top-k nearest-neighbour spatial weight matrix.

    Args:
        coords_df: frame indexed by node id with ``lon`` and ``lat`` columns.
        k: number of nearest neighbours kept per node (capped at n - 1).

    Returns:
        ``(ids, W, D, B)`` where ``D`` holds pairwise haversine distances in
        km (non-finite distances are stored as ``inf``), ``W`` holds
        inverse-distance weights that are row-normalised over each node's
        neighbours, then symmetrised and row-normalised again, and ``B`` is
        the 0/1 indicator of ``W > 0``. With fewer than two nodes all three
        matrices are zero.
    """
    ids = list(coords_df.index)
    lonlat = coords_df[["lon", "lat"]].values.astype(float)
    n = len(ids)
    if n < 2:
        return ids, np.zeros((n,n), dtype=float), np.zeros((n,n), dtype=float), np.zeros((n,n), dtype=int)

    # Pairwise distances, filled symmetrically from the upper triangle.
    D = np.zeros((n, n), dtype=float)
    for a in range(n):
        lon_a, lat_a = lonlat[a]
        for b in range(a + 1, n):
            lon_b, lat_b = lonlat[b]
            dist = haversine_km(lon_a, lat_a, lon_b, lat_b)
            if not np.isfinite(dist):
                dist = np.inf
            D[a, b] = dist
            D[b, a] = dist

    n_neighbours = min(k, n - 1)
    W = np.zeros((n, n), dtype=float)
    for row in range(n):
        dists = D[row].copy()
        dists[row] = np.inf  # a node is never its own neighbour
        nearest = np.argpartition(dists, n_neighbours)[:n_neighbours]
        inverse = 1.0 / np.maximum(dists[nearest], EPS)
        total = inverse.sum()
        if total > 0:
            W[row, nearest] = inverse / total

    W = 0.5*(W + W.T)   # symmetrise
    W = (W.T / (W.sum(axis=1) + EPS)).T  # row-normalise so every row sums to at most 1
    B = (W > 0).astype(int)
    return ids, W, D, B

# ========================= 制度网络 =========================

# Accepted header spellings for the policy table; the first one present in the
# CSV is used (see first_existing_col).
POLICY_ID_COLS    = ["Policy_ID", "policy_id", "POLICY_ID"]
POLICY_LEVEL_COLS = ["Policy_Level", "LEVEL", "level", "层级", "policy_level"]
POLICY_REGION_COLS= ["Region_std", "region_std", "REGION", "region"]

def std_level(level_str: str) -> str:
    """Map a free-text policy level to National / Provincial / Municipal / Other.

    Matching is by substring on the trimmed, lower-cased text; the rules are
    tried in the order National, Provincial, Municipal and the first hit wins.
    Missing values map to ``"Other"``.
    """
    if pd.isna(level_str):
        return "Other"
    text = str(level_str).strip().lower()
    rules = (
        ("National", ("national", "国家")),
        ("Provincial", ("prov", "省")),
        ("Municipal", ("市", "县", "munic", "city")),
    )
    for label, keywords in rules:
        if any(keyword in text for keyword in keywords):
            return label
    return "Other"

def clean_region_token(tok: str) -> str:
    """Trim a region name and strip trailing administrative suffixes.

    The suffixes are checked once each, in a fixed order, and removed when
    the current text ends with them. Missing values yield ``""``.
    """
    if pd.isna(tok):
        return ""
    name = str(tok).strip()
    for suffix in ("省", "市", "自治区", "特别行政区", "区", "县"):
        if name.endswith(suffix):
            name = name[: len(name) - len(suffix)]
    return name

def load_policy(year: int) -> pd.DataFrame:
    """Load ``{year}POLICY_clean.csv`` and collapse it to one row per policy.

    For every policy id the first non-null level is kept and all distinct
    regions are joined (sorted) with ``" | "``. The result has the columns
    ``Policy_ID``, ``Policy_Level``, ``Region_std``, ``Policy_Level_std``
    (via ``std_level``) and ``level_w`` (the institution-level weight, 0.0
    for levels without a configured weight).
    """
    raw = pd.read_csv(DATA_DIR / f"{year}POLICY_clean.csv")
    id_col = first_existing_col(raw, POLICY_ID_COLS)
    level_col = first_existing_col(raw, POLICY_LEVEL_COLS)
    region_col = first_existing_col(raw, POLICY_REGION_COLS)

    def _first_non_null(values):
        present = values.dropna()
        return present.iloc[0] if present.size else np.nan

    def _joined_regions(values):
        if not values.notna().any():
            return np.nan
        return " | ".join(sorted({x for x in values.dropna().astype(str)}))

    # NOTE(review): grouping by the Series (not the column name) with
    # as_index=False relies on pandas emitting the key as the first output
    # column; the three-name assignment below fails otherwise — confirm for
    # the pandas version in use.
    policies = (
        raw.groupby(raw[id_col], as_index=False)
           .agg({level_col: _first_non_null, region_col: _joined_regions})
    )
    policies.columns = ["Policy_ID", "Policy_Level", "Region_std"]
    policies["Policy_Level_std"] = policies["Policy_Level"].map(std_level)
    policies["level_w"] = policies["Policy_Level_std"].map(INSTITUTION_LEVEL_WEIGHTS).fillna(0.0)
    return policies

def build_policy_coverage(mpa_df: pd.DataFrame, pol_df: pd.DataFrame):
    """Build the MPA x policy coverage matrices.

    Coverage rules by standardised policy level:

    * ``National``   - covers every MPA;
    * ``Provincial`` - covers MPAs whose cleaned province matches any
      ``|``-separated token of the policy's ``Region_std``;
    * ``Municipal``  - same, matched against the cleaned city;
    * anything else  - covers nothing.

    Args:
        mpa_df: needs ``MPA_ID``; ``province`` and ``city`` are optional.
        pol_df: needs ``Policy_ID``, ``Policy_Level_std``, ``level_w`` and
            ``Region_std``.

    Returns:
        ``(Xw, Xb, ids, pol_ids, levels)``: ``Xw[i, k]`` is the level weight
        of policy ``k`` where it covers MPA ``i`` (else 0), ``Xb`` is the 0/1
        indicator, ``ids`` are the MPA ids as strings in row order.
    """
    ids = mpa_df["MPA_ID"].astype(str).tolist()

    def _region_groups(column):
        """Map cleaned region name -> set of (string) MPA ids located there."""
        if column not in mpa_df.columns:
            return {}
        # Key by the same string ids used for the matrix rows; keying by the
        # raw MPA_ID values broke the row lookup for non-string ids.
        id_to_region = dict(zip(ids, mpa_df[column].astype(str)))
        groups = {}
        for mpa_id, region in id_to_region.items():
            groups.setdefault(clean_region_token(region), set()).add(mpa_id)
        return groups

    def _ids_in_regions(region_field, groups):
        """Union of the MPA ids of every '|'-separated region token."""
        matched = set()
        for raw_token in region_field.split("|"):
            matched |= groups.get(clean_region_token(raw_token), set())
        return matched

    prov2ids = _region_groups("province")
    city2ids = _region_groups("city")

    pol_ids = pol_df["Policy_ID"].astype(str).tolist()
    levels  = pol_df["Policy_Level_std"].astype(str).tolist()
    level_w = pol_df["level_w"].astype(float).tolist()
    regions = pol_df["Region_std"].astype(str).tolist()

    N, P = len(ids), len(pol_ids)
    row_of = {mpa_id: row for row, mpa_id in enumerate(ids)}
    Xw = np.zeros((N, P), dtype=float)
    Xb = np.zeros((N, P), dtype=int)

    for col, (level, weight, region) in enumerate(zip(levels, level_w, regions)):
        # astype(str) turns a missing region into the literal "nan".
        region = region if region != "nan" else None
        if level == "National":
            covered = set(ids)
        elif level == "Provincial" and region:
            covered = _ids_in_regions(region, prov2ids)
        elif level == "Municipal" and region and len(city2ids) > 0:
            covered = _ids_in_regions(region, city2ids)
        else:
            covered = set()
        for mpa_id in covered:
            Xw[row_of[mpa_id], col] = weight
            Xb[row_of[mpa_id], col] = 1

    return Xw, Xb, ids, pol_ids, levels

# —— 制度相似度（权重制） ——

def institutional_co_mention_from_Xw(Xw: np.ndarray) -> np.ndarray:
    """Pairwise institutional overlap between the rows of ``Xw``.

    Entry ``(i, j)`` is the sum over columns of ``min(Xw[i], Xw[j])``; the
    matrix is symmetric and its diagonal is left at zero.
    """
    n_nodes = Xw.shape[0]
    W = np.zeros((n_nodes, n_nodes), dtype=float)
    upper_rows, upper_cols = np.triu_indices(n_nodes, 1)
    for a, b in zip(upper_rows, upper_cols):
        overlap = float(np.minimum(Xw[a], Xw[b]).sum())
        W[a, b] = overlap
        W[b, a] = overlap
    return W

def soft_coverage_matrix(Wco: np.ndarray, q=0.75):
    """Rescale co-mention weights into [0, 1] using a quantile cap.

    ``tau_q`` is the ``q``-quantile of the strictly positive upper-triangle
    entries of ``Wco`` (1.0 when there are none). Each weight is divided by
    ``tau_q`` and capped at 1; the diagonal is zeroed.

    Returns:
        ``(C, tau_q)``.
    """
    upper = Wco[np.triu_indices_from(Wco, 1)]
    positive = upper[upper > 0]
    if positive.size == 0:
        tau_q = 1.0
    else:
        tau_q = float(np.quantile(positive, q))
        if tau_q <= 0:
            tau_q = float(positive.max())
    scale = tau_q if tau_q > 0 else 1.0
    C = np.minimum(1.0, Wco / scale)
    np.fill_diagonal(C, 0.0)
    return C, float(tau_q)

# —— UICS：数量凹函数 + 层级因子（仅用于 A* 与拓扑） ——

def build_K_and_nonNat(Xb: np.ndarray, levels: list):
    """Count shared policies per MPA pair and flag shared non-national ones.

    Returns:
        ``(K, NonNat)``: ``K[i, j]`` is the number of policies covering both
        ``i`` and ``j`` (as float); ``NonNat[i, j]`` is 1.0 when at least one
        of those shared policies is not ``"National"``. Both diagonals are 0.
    """
    n_nodes, n_policies = Xb.shape
    shared = Xb @ Xb.T
    is_non_national = np.array(
        [levels[col] != "National" for col in range(n_policies)], dtype=bool
    )
    sub = Xb[:, is_non_national]
    if sub.shape[1] > 0:
        shared_non_national = sub @ sub.T
    else:
        shared_non_national = np.zeros((n_nodes, n_nodes), dtype=int)
    NonNat = (shared_non_national >= 1).astype(float)
    np.fill_diagonal(shared, 0)
    np.fill_diagonal(NonNat, 0)
    return shared.astype(float), NonNat

def uics_unified_score(C_soft: np.ndarray, K: np.ndarray, NonNat: np.ndarray, k0=K0, level_base=LEVEL_BASE):
    """Combine soft coverage, shared-policy count and level into one score.

    The score is ``C_soft * clip(K / k0, 0, 1) * level_factor`` where
    ``level_factor`` equals ``level_base`` for pairs without a shared
    non-national policy and 1.0 for pairs with one. The diagonal is zeroed.
    """
    quantity_term = np.clip(K / k0, 0.0, 1.0)
    level_term = level_base + (1.0 - level_base) * NonNat
    score = C_soft * quantity_term * level_term
    np.fill_diagonal(score, 0.0)
    return score

# —— 制度二值网络 ——

def binary_from_density_exact(W: np.ndarray, target_density: float):
    """Binarise ``W`` by keeping exactly the strongest share of node pairs.

    The number of edges is ``round(target_density * n_pairs)`` clamped to
    ``[0, n_pairs]``. Pairs are ranked by descending weight with a stable
    sort, so ties are broken by upper-triangle (row-major) order.

    NOTE(review): when fewer pairs have a positive weight than the target,
    zero-weight pairs are selected as well to reach the count.

    Returns:
        A symmetric 0/1 integer adjacency matrix with a zero diagonal.
    """
    n = W.shape[0]
    upper = np.triu_indices(n, 1)
    weights = W[upper].astype(float)
    n_pairs = len(weights)
    n_edges = max(0, min(int(round(target_density * n_pairs)), n_pairs))

    A = np.zeros_like(W, dtype=int)
    if n_edges == 0 or n_pairs == 0:
        np.fill_diagonal(A, 0)
        return A

    ranked = np.argsort(-weights, kind="mergesort")
    top = ranked[:n_edges]
    cutoff = weights[top[-1]]
    chosen = np.zeros(n_pairs, dtype=bool)
    # The comparison keeps the result empty when the cutoff is NaN.
    chosen[top[weights[top] >= cutoff]] = True

    A[upper] = chosen.astype(int)
    A = A + A.T
    np.fill_diagonal(A, 0)
    return A

# ========================= 双网络差分：公共子图的 M_i^raw =========================

def dual_diff_Mi_raw(Ws: np.ndarray, Wi: np.ndarray) -> np.ndarray:
    """Per-node raw misfit between the spatial and institutional weights.

    For node ``i`` the neighbourhood is the union of its positive entries in
    either matrix (excluding ``i`` itself). The misfit is
    ``sum|Ws - Wi| / sum(Ws + Wi)`` over that neighbourhood, and 0 when the
    neighbourhood is empty or the denominator is not positive.
    """
    n = Ws.shape[0]
    M_raw = np.zeros(n, dtype=float)
    for node in range(n):
        linked = np.flatnonzero((Ws[node] > 0) | (Wi[node] > 0))
        linked = linked[linked != node]
        if linked.size == 0:
            continue  # stays 0.0
        spatial = Ws[node, linked]
        institutional = Wi[node, linked]
        total = float(np.sum(spatial + institutional))
        if total > 0:
            gap = float(np.sum(np.abs(spatial - institutional)))
            M_raw[node] = gap / total
    return M_raw

# ========================= 拓扑与导出 =========================

def triangles_per_node(A: np.ndarray) -> np.ndarray:
    """Count, for every node, the triangles it takes part in.

    Neighbours are the positive entries of each row of ``A``. For a node the
    common neighbours shared with each of its neighbours are summed and
    halved, since every triangle is seen from both of its other corners.
    """
    n = A.shape[0]
    neighbours = [set(np.flatnonzero(A[i] > 0)) for i in range(n)]
    counts = np.zeros(n, dtype=int)
    for node, own in enumerate(neighbours):
        shared = sum(len(own & neighbours[other]) for other in own)
        counts[node] = shared // 2
    return counts

def topology_metrics_from_adj(A: np.ndarray, ids, label: str, year: int):
    """Compute whole-network topology statistics for one adjacency matrix.

    Args:
        A: square adjacency matrix; converted with ``nx.from_numpy_array``.
        ids: node labels, in the row order of ``A``.
        label: value stored under ``"network"`` in the result.
        year: value stored under ``"year"`` in the result.

    Returns:
        A dict with size, density, degree-distribution statistics,
        clustering, algebraic connectivity of the largest connected
        component, global efficiency, greedy-modularity communities and the
        entropy / normalised HHI of the eigenvector-centrality distribution.

    NOTE: the largest component's average path length, diameter and radius
    are computed below but are not included in the returned dict.
    """
    G = nx.from_numpy_array(A)
    mapping = {i: ids[i] for i in range(len(ids))}
    G = nx.relabel_nodes(G, mapping)
    n = G.number_of_nodes(); m = G.number_of_edges()
    density = nx.density(G)
    # Degree distribution summaries (population variance / std, ddof=0).
    degs = np.array([d for _, d in G.degree()], dtype=float)
    avg_degree = degs.mean() if len(degs) > 0 else np.nan
    degree_var = degs.var(ddof=0) if len(degs) > 0 else np.nan
    degree_cv = (degs.std(ddof=0)/avg_degree) if avg_degree>0 else np.nan
    degree_gini = gini(degs)
    transitivity = nx.transitivity(G)
    avg_clust = nx.average_clustering(G) if n>0 else np.nan
    # Largest connected component; only extracted when the graph has at least one edge.
    Gcc = None
    if n>0 and m>0:
        comps = sorted(nx.connected_components(G), key=len, reverse=True)
        Gcc = G.subgraph(comps[0]).copy()
    try:
        gcc_avg_pl = nx.average_shortest_path_length(Gcc) if (Gcc and Gcc.number_of_nodes()>1) else np.nan
    except Exception:
        gcc_avg_pl = np.nan
    try:
        diameter_gcc = nx.diameter(Gcc) if (Gcc and Gcc.number_of_nodes()>1) else np.nan
    except Exception:
        diameter_gcc = np.nan
    try:
        radius_gcc = nx.radius(Gcc) if (Gcc and Gcc.number_of_nodes()>1) else np.nan
    except Exception:
        radius_gcc = np.nan
    def algebraic_connectivity_numpy(Gin):
        # Second-smallest eigenvalue of the combinatorial Laplacian D - A,
        # computed densely with numpy; NaN when it cannot be computed.
        if Gin is None or Gin.number_of_nodes()<2: return np.nan
        try:
            A_mat = nx.to_numpy_array(Gin, dtype=float)
        except Exception:
            return np.nan
        A_mat = np.maximum(A_mat, 0.0)
        deg = A_mat.sum(axis=1)
        L = np.diag(deg) - A_mat
        try:
            vals = np.linalg.eigvalsh(L)
        except Exception:
            return np.nan
        vals = np.sort(vals)
        return float(vals[1]) if len(vals) >= 2 else np.nan
    alg_conn = algebraic_connectivity_numpy(Gcc)
    global_eff = nx.global_efficiency(G)
    try:
        comms = list(nx.algorithms.community.greedy_modularity_communities(G))
        Q = nx.algorithms.community.modularity(G, comms)
        num_comm = len(comms)
    except Exception:
        Q, num_comm = np.nan, np.nan
    # Eigenvector centrality: numpy solver first, power iteration as fallback.
    # NOTE(review): the fallback call is not guarded, so if it also fails
    # the exception propagates to the caller — confirm that is acceptable.
    try:
        c = nx.eigenvector_centrality_numpy(G)
    except Exception:
        c = nx.eigenvector_centrality(G, max_iter=2000, tol=1e-06)
    v = np.array(list(c.values()), dtype=float)
    v = np.maximum(v, 0)
    s = v.sum()
    if s>0:
        # Treat the centralities as a probability distribution over nodes.
        p = v/s; eps=1e-12
        ent = -float(np.sum(p*np.log(p+eps)))
        hhi = float(np.sum(p**2))
        # Rescale HHI from [1/n, 1] to [0, 1].
        hhi_norm = (hhi - 1/n) / (1 - 1/n) if n>1 else np.nan
    else:
        ent, hhi_norm = np.nan, np.nan
    return {
        "network": label, "year": year,
        "n": n, "m": m, "density": density, "avg_degree": avg_degree,
        "degree_var": degree_var, "degree_cv": degree_cv, "degree_gini": degree_gini,
        "transitivity": transitivity, "avg_clustering": avg_clust,
        "algebraic_connectivity_gcc": alg_conn, "global_efficiency": global_eff,
        "modularity_Q": Q, "num_communities": num_comm,
        "eigenvector_entropy": ent, "eigenvector_hhi_norm": hhi_norm,
    }

def save_matrix(ids, M, path_csv: Path):
    """Write the square matrix ``M`` to CSV with ``ids`` as row and column labels (UTF-8 with BOM)."""
    labelled = pd.DataFrame(M, index=ids, columns=ids)
    labelled.to_csv(path_csv, encoding="utf-8-sig")

def save_edgelist(ids, W, path_csv: Path, binary=False, weight_threshold=1e-12):
    """Write the upper-triangle edges of ``W`` as a source/target/weight CSV.

    With ``binary=True`` every pair with a positive entry is written with
    weight 1. Otherwise pairs whose weight exceeds ``weight_threshold`` are
    written with their float weight. The file is UTF-8 with BOM, no index.
    """
    n = len(ids)
    edges = []
    for a in range(n):
        for b in range(a + 1, n):
            weight = W[a, b]
            if binary and weight > 0:
                edges.append((ids[a], ids[b], 1))
            elif not binary and weight > weight_threshold:
                edges.append((ids[a], ids[b], float(weight)))
    table = pd.DataFrame(edges, columns=["source","target","weight"])
    table.to_csv(path_csv, index=False, encoding="utf-8-sig")

# ========================= 主流程（逐年） =========================

def process_one_year(year: int):
    """Run the full single-wave pipeline and write that wave's outputs.

    Reads ``{year}MPA.xlsx`` and the wave's policy CSV, builds the spatial
    and institutional networks, restricts both to the nodes with finite
    coordinates, computes the raw misfit ``M_raw`` per node, and writes the
    node table, matrices and edge lists under ``OUT_DIR``.

    Returns:
        dict with keys ``year``, ``public_ids``, ``Ws``, ``Wi``, ``Bsp``,
        ``A_bin``, ``nodes_df``, ``topo_sp`` and ``topo_in``.
    """
    # 1) Read MPA attributes and build the spatial network
    mpa_df = load_mpa_attrs(DATA_DIR / f"{year}MPA.xlsx")
    mpa_df['MPA_ID'] = mpa_df['MPA_ID'].astype(str)
    coords = mpa_df.set_index('MPA_ID')[['lon','lat']]

    ids_s, W_s, D_s, B_s = build_spatial_W_topk(coords, k=K_SPATIAL_NEIGHBORS)

    # 2) Read policies and build the coverage matrices
    pol = load_policy(year)
    Xw, Xb, ids_i, pol_ids, levels = build_policy_coverage(mpa_df, pol)
    # NOTE(review): `assert` is skipped under `python -O`.
    assert ids_i == ids_s, "MPA 顺序不一致：空间与制度"

    # 3) Institutional weights: co-mention -> soft coverage -> UICS
    #    (P feeds the binary network A* and topology; C_soft feeds the difference measure)
    W_co = institutional_co_mention_from_Xw(Xw)
    C_soft, _ = soft_coverage_matrix(W_co, q=SOFT_TAU_Q)
    K, NonNat = build_K_and_nonNat(Xb, levels)
    P = uics_unified_score(C_soft, K, NonNat, k0=K0, level_base=LEVEL_BASE)

    # 4) Common sub-graph: only finite coordinates are required (zero-coverage MPAs are kept)
    mask_public = (np.isfinite(coords['lon'])) & (np.isfinite(coords['lat']))
    public_ids = list(coords.index[mask_public].astype(str))
    # Row positions of the retained nodes in the full matrices.
    idx = [ids_s.index(i) for i in public_ids]

    # 5) Weights for the dual-network difference (Ws and Wi on the same scale: symmetric + row-normalised)
    Ws = W_s[np.ix_(idx, idx)].copy()
    Ci = C_soft[np.ix_(idx, idx)].copy()
    Ci = 0.5*(Ci + Ci.T)
    Ci = (Ci.T / (Ci.sum(axis=1) + EPS)).T
    Wi = Ci

    # 6) M_raw, spatial binary degree (SC_deg) and coverage statistics
    M_raw = dual_diff_Mi_raw(Ws, Wi)
    Bsp = (Ws > 0).astype(int)
    SC_deg = Bsp.sum(axis=1)

    # Coverage flag: covered by any policy, regardless of level
    covered_any_all = (Xb.sum(axis=1) > 0).astype(int)
    covered_any = covered_any_all[idx].astype(int)

    # Row sums of shared non-national policy counts (kept for reference)
    nonnat_mask = np.array([1 if levels[c] != 'National' else 0 for c in range(len(levels))], dtype=bool)
    K_nonNat = (Xb[:, nonnat_mask] @ Xb[:, nonnat_mask].T) if nonnat_mask.sum()>0 else np.zeros((Xb.shape[0], Xb.shape[0]), dtype=int)
    np.fill_diagonal(K_nonNat, 0)
    deg_nonNat = K_nonNat[np.ix_(idx, idx)].sum(axis=1)

    # Binary institutional network (strong edges only)
    A_bin = binary_from_density_exact(P[np.ix_(idx, idx)], target_density=TARGET_BINARY_DENSITY)

    # 7) Assemble the per-node table for this wave
    sea_series  = mpa_df.set_index('MPA_ID').loc[public_ids, 'sea_region'].astype(str)
    prov_series = mpa_df.set_index('MPA_ID').loc[public_ids, 'province'].astype(str)

    nodes = pd.DataFrame({
        'MPA_ID': public_ids,
        'M_raw': M_raw,
        'SC_deg': SC_deg,
        'deg_nonNat': deg_nonNat,
        'SeaRegion': sea_series.values,
        'province': prov_series.values,
        'covered_any': covered_any
    })
    # NOTE: build_unified_and_fill_M writes this same path again later, with the M column added.
    nodes.to_csv(OUT_DIR/"nodes"/f"misfit_dual_diff_{year}.csv", index=False, encoding='utf-8-sig')

    # 8) Save matrices and edge lists
    save_matrix(public_ids, Ws,    OUT_DIR/"matrices"/f"spatial_W_{year}_public.csv")
    save_matrix(public_ids, Wi,    OUT_DIR/"matrices"/f"inst_W_{year}_public.csv")
    save_matrix(public_ids, A_bin, OUT_DIR/"matrices"/f"inst_Astar_{year}_public.csv")

    # ArcGIS edgelists
    save_edgelist(public_ids, Ws,    OUT_DIR/"edgelists"/f"spatial_edges_full_{year}.csv", binary=False)
    save_edgelist(public_ids, A_bin, OUT_DIR/"edgelists"/f"inst_edges_strong_{year}.csv", binary=True)

    # Topology (spatial binary network: Bsp; institutional binary network: A_bin)
    topo_sp = topology_metrics_from_adj(Bsp, public_ids, label='spatial',         year=year)
    topo_in = topology_metrics_from_adj(A_bin, public_ids, label='institutional', year=year)

    return {
        'year': year,
        'public_ids': public_ids,
        'Ws': Ws,
        'Wi': Wi,
        'Bsp': Bsp,
        'A_bin': A_bin,
        'nodes_df': nodes,
        'topo_sp': topo_sp,
        'topo_in': topo_in
    }

# ========================= 统一 M 与分级 =========================

def grade_q50_q75_q90(series: pd.Series):
    """Grade values into Minor / Moderate / Major / Critical / None.

    The cut points are the 50th, 75th and 90th percentiles of the strictly
    positive values. Values that are zero (or fall outside every bin) are
    graded ``'None'``. Intervals are right-closed, so a value equal to a cut
    point belongs to the lower grade.

    Returns:
        A ``pd.Categorical`` of ``'None'`` when there is no positive value,
        otherwise a categorical ``pd.Series`` aligned with ``series``.
    """
    labels = ['Minor','Moderate','Major','Critical']
    pos = series[series > 0]
    if len(pos) == 0:
        return pd.Categorical(["None"]*len(series),
                              categories=labels + ['None'])
    q50, q75, q90 = pos.quantile([0.50, 0.75, 0.90]).tolist()
    eps = 1e-12
    # Build each edge from the previous *edge* (not the previous quantile):
    # with q50 == q75 == q90 the old construction produced two identical
    # edges and pd.cut raised "Bin edges must be unique".
    edge_minor = max(q50, eps)
    edge_moderate = max(q75, edge_minor + eps)
    edge_major = max(q90, edge_moderate + eps)
    bins = [0.0, edge_minor, edge_moderate, edge_major, float('inf')]
    graded = pd.cut(series.replace(0, np.nan), bins=bins, labels=labels, include_lowest=False)
    return graded.cat.add_categories('None').fillna('None')

def build_unified_and_fill_M(results_by_year: dict):
    """Min-max scale ``M_raw`` over all waves pooled together and export it.

    Side effects (the function itself returns ``None``):

    * ``results_by_year[y]['nodes_df']`` is replaced by a copy carrying the
      pooled, scaled column ``M``;
    * each wave's ``misfit_dual_diff_{y}.csv`` is rewritten with ``M``;
    * ``misfit_nodes_all_years_unified.csv`` is written with one row per MPA,
      its sea region, ``M_{y}`` and ``misfit_grade_{y}`` for every wave.
    """
    pooled = pd.concat(
        [
            results_by_year[yr]['nodes_df'][['MPA_ID','M_raw']].assign(Year=yr)
            for yr in YEARS
        ],
        ignore_index=True,
    )
    raw = pooled['M_raw'].astype(float)
    lowest, highest = float(raw.min()), float(raw.max())
    if highest > lowest:
        pooled['M'] = (raw - lowest) / (highest - lowest)
    else:
        pooled['M'] = 0.0

    # Write the scaled values back into each wave's node table.
    export_cols = ['MPA_ID','M_raw','M','SC_deg','deg_nonNat','SeaRegion','province','covered_any']
    for yr in YEARS:
        scaled = pooled.loc[pooled['Year']==yr, ['MPA_ID','M']]
        merged = results_by_year[yr]['nodes_df'].merge(scaled, on='MPA_ID', how='left')
        results_by_year[yr]['nodes_df'] = merged
        merged[export_cols].to_csv(
            OUT_DIR/"nodes"/f"misfit_dual_diff_{yr}.csv", index=False, encoding='utf-8-sig'
        )

    # Unified table (with grades): one row per MPA seen in any wave.
    seen_ids = set()
    sea_by_id = {}
    for yr in YEARS:
        frame = results_by_year[yr]['nodes_df']
        seen_ids |= set(frame['MPA_ID'])
        # Later waves overwrite earlier ones when the sea region differs.
        sea_by_id.update(frame.set_index('MPA_ID')['SeaRegion'].astype(str).to_dict())
    uni = pd.DataFrame({'MPA_ID': sorted(seen_ids, key=str)})
    uni['SeaRegion'] = uni['MPA_ID'].map(sea_by_id)
    for yr in YEARS:
        wave_m = results_by_year[yr]['nodes_df'][['MPA_ID','M']].rename(columns={'M':f'M_{yr}'})
        uni = uni.merge(wave_m, on='MPA_ID', how='left')
    for yr in YEARS:
        uni[f'misfit_grade_{yr}'] = grade_q50_q75_q90(uni[f'M_{yr}'].fillna(0))
    uni.to_csv(OUT_DIR/"nodes"/"misfit_nodes_all_years_unified.csv", index=False, encoding='utf-8-sig')

# ========================= 面板与 2015 标签 =========================

def make_panel(results_by_year: dict, high_pctl: float, sc_gate_q: float, out_path: Path):
    """Build one ALAAM panel (Critical-only at a given percentile, plus SC gate) and write it.

    For every wave ``t`` in ``PANEL_TARGET_YEARS`` the outcome ``Y_t`` is 1
    when a node's ``M`` exceeds its sea region's ``high_pctl`` quantile
    (floored at ``ABS_BASELINE``) and its ``SC_deg`` is at or above the sea
    region's ``sc_gate_q`` quantile. Structural predictors are taken from
    the previous wave's binary institutional network.

    Args:
        results_by_year: per-wave results carrying ``nodes_df`` (with ``M``),
            ``public_ids`` and ``A_bin``.
        high_pctl: within-sea-region quantile used as the Critical threshold.
        sc_gate_q: within-sea-region quantile of ``SC_deg`` used as the gate.
        out_path: CSV file the stacked panel is written to.

    Returns:
        DataFrame with one row per target wave: ``pctl``, ``scq``, ``year``,
        ``total``, ``Y_t`` (count of ones) and ``rate``.
    """
    rows, stats_rows = [], []
    for t in PANEL_TARGET_YEARS:
        # NOTE(review): for the first entry of YEARS the index would be -1 and
        # wrap to the last wave; PANEL_TARGET_YEARS currently avoids that case.
        t_prev = YEARS[YEARS.index(t)-1]
        df_t = results_by_year[t]['nodes_df'].copy()

        # A) SC gate (quantile within sea region, on the spatial binary degree SC_deg)
        gates = df_t.groupby('SeaRegion')['SC_deg'].quantile(sc_gate_q)
        df_t['SC_gate'] = df_t.apply(lambda r: r['SC_deg'] >= gates.get(r['SeaRegion'], r['SC_deg']), axis=1)

        # B) Critical threshold (per sea region, per wave)
        tau_by_sea = df_t.groupby('SeaRegion')['M'].quantile(high_pctl).to_dict()
        def _y_soft(r):
            tau = max(ABS_BASELINE, tau_by_sea.get(r['SeaRegion'], 0.0))
            return int((r['M'] > tau) and r['SC_gate'])
        df_t['Y_soft'] = df_t.apply(_y_soft, axis=1)

        # C) Y_t (Critical-only)
        df_t['Y_t'] = df_t['Y_soft'].astype(int)

        # Structural terms come from the binary institutional network of wave t-1
        ids_prev = results_by_year[t_prev]['public_ids']
        A_prev = results_by_year[t_prev]['A_bin']
        n_prev = A_prev.shape[0]
        deg_prev = pd.Series(A_prev.sum(axis=1), index=ids_prev)
        logdeg_prev = np.log1p(deg_prev)
        tri_prev = pd.Series(triangles_per_node(A_prev), index=ids_prev)
        # Two-paths i-k-j counted over non-adjacent, distinct endpoint pairs.
        two_open_prev = pd.Series((A_prev.dot(A_prev) * (1 - A_prev) * (1 - np.eye(n_prev))).sum(axis=1), index=ids_prev)
        G_prev = nx.from_numpy_array(A_prev)
        bw_dict = nx.betweenness_centrality(G_prev, normalized=True)
        # NOTE(review): labels are assigned by enumeration position, not by the
        # `node` key, so this relies on bw_dict iterating in row order — confirm.
        bw_vals = pd.Series({ids_prev[i]: bw for i,(node,bw) in enumerate(bw_dict.items())})
        bw_z = (bw_vals - bw_vals.mean()) / (bw_vals.std(ddof=0) if bw_vals.std(ddof=0)>0 else 1.0)

        # Lagged high-misfit flag: M above max(Q3 of positive M at t-1, 0.60)
        prev = results_by_year[t_prev]['nodes_df']
        prev_pos = prev['M'][prev['M']>0]
        Q3_prev = float(prev_pos.quantile(0.75)) if len(prev_pos)>0 else 0.0
        tau_high = max(Q3_prev, 0.60)
        high_prev = (prev.set_index('MPA_ID')['M'] > tau_high).astype(int)

        # Nodes absent from wave t-1 get 0 for every lagged term (fillna(0)).
        out = pd.DataFrame({
            'Year': t,
            'MPA_ID': df_t['MPA_ID'],
            'Y_t': df_t['Y_t'].astype(int),
            'high_misfit_t1': df_t['MPA_ID'].map(high_prev).fillna(0).astype(int).values,
            'logdeg_t1': df_t['MPA_ID'].map(logdeg_prev).fillna(0).values,
            'triangles_t1': df_t['MPA_ID'].map(tri_prev).fillna(0).values,
            'two_path_open_t1': df_t['MPA_ID'].map(two_open_prev).fillna(0).values,
            'corridor_betweenness_t1': df_t['MPA_ID'].map(bw_vals).fillna(0).values,
            'corridor_betweenness_t1_z': df_t['MPA_ID'].map(bw_z).fillna(0).values,
            'SeaRegion': df_t['SeaRegion'],
            'province': df_t['province'],
            'SC_gate': df_t['SC_gate'].astype(int),
            'covered_any_t': df_t.get('covered_any', pd.Series([0]*len(df_t))).astype(int),
            'M': df_t['M']
        })
        rows.append(out)

        total = int(len(df_t))
        ysum  = int(df_t['Y_t'].sum())
        stats_rows.append({
            "pctl": high_pctl, "scq": sc_gate_q, "year": t,
            "total": total, "Y_t": ysum, "rate": (ysum/total if total else np.nan)
        })

    panel = pd.concat(rows, ignore_index=True)
    panel.to_csv(out_path, index=False, encoding='utf-8-sig')
    return pd.DataFrame(stats_rows)
# ========================= 稳健性面板导出（新增，不改主面板） =========================

def _row_normalize(A: np.ndarray, eps=1e-12):
    """Divide every row of ``A`` by its sum; rows whose sum is not positive are left unscaled.

    ``eps`` is accepted for interface compatibility and is not used.
    """
    row_totals = A.sum(axis=1, keepdims=True)
    # Avoid dividing by zero (or a negative total): treat such rows as total 1.
    row_totals[row_totals <= 0] = 1.0
    return A / row_totals

def _z(series: pd.Series):
    """Standardise ``series`` to zero mean and unit population std (ddof=0).

    A series with zero spread is only centred (divided by 1.0).
    """
    values = series.astype(float)
    centre = values.mean()
    spread = values.std(ddof=0)
    divisor = spread if spread > 0 else 1.0
    return (values - centre) / divisor

def make_robust_export(results_by_year: dict, high_pctl: float, sc_gate_q: float, out_path: Path):
    """Write the robustness-check panel with all extra RHS terms in one file.

    For every year in ``PANEL_TARGET_YEARS`` the lagged (previous-period)
    covariates are computed from ``results_by_year[t_prev]`` and mapped onto
    the nodes of year ``t``.  ``Y_t`` is taken from the already written main
    panel ``alaam_nodes_panel.csv`` (missing pairs default to 0).

    Parameters
    ----------
    results_by_year : dict
        Keyed by year; each entry must provide ``'nodes_df'`` (columns
        ``MPA_ID``, ``M``, ``province``, ``SeaRegion``), ``'public_ids'``,
        ``'A_bin'`` (binary adjacency used for degree, open two-paths and
        betweenness) and ``'Bsp'`` (matrix used for the spatial lag).
        The frames are NOT modified.
    high_pctl, sc_gate_q : float
        Accepted for signature parity with the other exporters; not used here.
    out_path : Path
        Destination CSV (written with ``utf-8-sig`` encoding).

    Returns
    -------
    pandas.DataFrame
        The concatenated panel that was written to ``out_path``.

    Raises
    ------
    FileNotFoundError
        If the main panel CSV does not exist yet.
    """
    # (Year, MPA_ID) -> Y_t lookup built from the main panel.
    main_panel_csv = OUT_DIR/"alaam"/"alaam_nodes_panel.csv"
    if not main_panel_csv.exists():
        raise FileNotFoundError(f"找不到主面板 {main_panel_csv}，请先运行 make_panel（主口径）")
    mp = pd.read_csv(main_panel_csv, dtype={"MPA_ID": str})
    y_map = mp.set_index(["Year", "MPA_ID"])["Y_t"].to_dict()

    robust_rows = []
    for t in PANEL_TARGET_YEARS:
        t_prev = YEARS[YEARS.index(t)-1]

        # Current-period nodes.  Work on a copy and use string IDs everywhere
        # so that every lookup below shares one key type.
        df_t = results_by_year[t]['nodes_df'].copy()
        df_t['MPA_ID'] = df_t['MPA_ID'].astype(str)
        mpa_ids = df_t['MPA_ID']

        # ---- previous-period inputs ----
        ids_prev = [str(x) for x in results_by_year[t_prev]['public_ids']]
        A_prev = results_by_year[t_prev]['A_bin']
        B_prev = results_by_year[t_prev]['Bsp']
        n_prev = len(ids_prev)

        deg_prev = pd.Series(A_prev.sum(axis=1), index=ids_prev)
        logdeg_prev = np.log1p(deg_prev)
        # Open two-paths: length-2 walks between distinct, non-adjacent nodes.
        two_open_prev = pd.Series(
            (A_prev.dot(A_prev) * (1 - A_prev) * (1 - np.eye(n_prev))).sum(axis=1),
            index=ids_prev
        )

        # Betweenness: graph nodes are the integer positions 0..n-1, so label
        # each value through its own node key rather than enumeration order.
        G_prev = nx.from_numpy_array(A_prev)
        bw_by_node = nx.betweenness_centrality(G_prev, normalized=True)
        bw_prev = pd.Series({ids_prev[node]: bw for node, bw in bw_by_node.items()}, dtype=float)
        bw_prev_z = _z(bw_prev).rename('corridor_betweenness_t1_z')

        # Previous-period "high misfit" flag (same rule as the main panel).
        # Copy first: the shared nodes_df must not be mutated by the cast.
        prev = results_by_year[t_prev]['nodes_df'].copy()
        prev['MPA_ID'] = prev['MPA_ID'].astype(str)
        prev_nodes = prev.set_index('MPA_ID')
        prev_pos = prev['M'][prev['M'] > 0]
        Q3_prev = float(prev_pos.quantile(0.75)) if len(prev_pos) > 0 else 0.0
        tau_high = max(Q3_prev, 0.60)
        high_prev = (prev_nodes['M'] > tau_high).astype(int)

        # Spatial lag: row-normalised Bsp applied to the high-misfit flag.
        B_prev_norm = _row_normalize(B_prev.astype(float))
        hp_vec = high_prev.reindex(ids_prev).fillna(0).values
        spat_lag_prev = pd.Series(B_prev_norm.dot(hp_vec), index=ids_prev, name='spatial_lag_misfit_t1')
        spat_lag_prev_z = _z(spat_lag_prev).rename('spatial_lag_misfit_t1_z')

        # Province leave-one-out count: province total of the flag minus own flag.
        prov_prev = prev_nodes['province'].astype(str)
        high_by_prov = high_prev.groupby(prov_prev).sum()
        looi_raw = (prov_prev.map(high_by_prov) - high_prev).astype(float).rename('z_prov_count_looi_t1')
        looi_z = _z(looi_raw).rename('z_prov_count_looi_t1_z')

        # Orthogonalise open two-paths against const / logdeg / betweenness / lag.
        X = np.column_stack([
            np.ones(n_prev),
            logdeg_prev.reindex(ids_prev).fillna(0).values,
            bw_prev_z.reindex(ids_prev).fillna(0).values,
            spat_lag_prev_z.reindex(ids_prev).fillna(0).values,
        ])
        y_two = two_open_prev.reindex(ids_prev).astype(float).fillna(0).values
        try:
            beta, *_ = np.linalg.lstsq(X, y_two, rcond=None)
            resid = y_two - X.dot(beta)
        except (np.linalg.LinAlgError, ValueError):
            # Best-effort fallback: centre the raw counts.
            resid = y_two - np.nanmean(y_two)
        two_open_ortho_z = _z(pd.Series(resid, index=ids_prev)).rename('two_path_open_t1_ortho_z')

        # NOTE(review): despite the "_ortho" in the column name this top-quartile
        # dummy is built from the RAW two-path counts, whereas main() derives the
        # same-named column from the residual.  Behaviour kept; confirm intent.
        y_two_s = pd.Series(y_two, index=ids_prev)
        q4 = float(y_two_s.quantile(0.75))
        two_topQ4 = (y_two_s > q4).astype(int).rename('two_path_topQ4_ortho_t1')

        def _from_prev(lookup: pd.Series):
            # Map a previous-period series onto current nodes; unseen nodes -> 0.
            return mpa_ids.map(lookup).fillna(0)

        # ---- assemble the current-period rows ----
        out = pd.DataFrame({
            'Year': t,
            'MPA_ID': mpa_ids,
            # Key point: the real Y_t comes from the main panel.
            'Y_t': [y_map.get((t, mid), 0) for mid in mpa_ids],
            'high_misfit_t1': _from_prev(high_prev).astype(int).values,
            'logdeg_t1': _from_prev(logdeg_prev).values,
            'corridor_betweenness_t1': _from_prev(bw_prev).values,
            'corridor_betweenness_t1_z': _from_prev(bw_prev_z).values,
            'two_path_open_t1': _from_prev(two_open_prev).values,
            'spatial_lag_misfit_t1': _from_prev(spat_lag_prev).values,
            'spatial_lag_misfit_t1_z': _from_prev(spat_lag_prev_z).values,
            'z_prov_count_looi_t1': _from_prev(looi_raw).values,
            'z_prov_count_looi_t1_z': _from_prev(looi_z).values,
            'two_path_open_t1_ortho_z': _from_prev(two_open_ortho_z).values,
            'two_path_topQ4_ortho_t1': _from_prev(two_topQ4).astype(int).values,
            'sea_region': df_t['SeaRegion'].values,
            'province': df_t['province'].values,
        })
        out['logdeg_t1_z'] = _z(out['logdeg_t1']).values
        robust_rows.append(out)

    panel_robust = pd.concat(robust_rows, ignore_index=True)
    panel_robust.to_csv(out_path, index=False, encoding='utf-8-sig')
    print("Robustness-ready panel saved:", out_path, f"(rows={len(panel_robust)})")
    # Quick self-check that Y_t is not constant.
    pos_rate = panel_robust['Y_t'].mean()
    print(f"[selfcheck] Y_t variation: unique={panel_robust['Y_t'].nunique()}  pos_rate={pos_rate:.4f}")
    return panel_robust




def make_2015_labels(results_by_year: dict, high_pctl: float, sc_gate_q: float, out_path: Path):
    """Label the 2015 nodes (Critical-only at ``high_pctl`` plus SC gate) and write them.

    Per ``SeaRegion``:
      * ``SC_gate``  is True when ``SC_deg`` >= the region's ``sc_gate_q`` quantile
        (rows whose region has no quantile are compared with their own value);
      * ``Critical`` is 1 when ``M`` > max(``ABS_BASELINE``, region ``high_pctl``
        quantile of ``M``), with 0.0 standing in for a missing quantile;
      * ``Y_2015``   is 1 only when both hold.

    The thresholds are looked up with vectorised operations instead of a
    row-wise ``DataFrame.apply``, which also keeps an empty ``nodes_df``
    working.

    Parameters
    ----------
    results_by_year : dict
        Must contain ``results_by_year[2015]['nodes_df']`` with the columns
        ``MPA_ID``, ``SeaRegion``, ``province``, ``M`` and ``SC_deg``.
    high_pctl : float
        Quantile of ``M`` used for the Critical threshold.
    sc_gate_q : float
        Quantile of ``SC_deg`` used for the SC gate.
    out_path : Path
        Destination CSV (``utf-8-sig``).

    Returns
    -------
    dict
        Keys ``pctl``, ``scq``, ``year``, ``total``, ``Y_t``, ``rate``.
    """
    df = results_by_year[2015]['nodes_df'].copy()
    sea = df['SeaRegion']
    by_sea = df.groupby('SeaRegion')

    # SC gate: region quantile of SC_deg; fall back to the row's own value when
    # the region has no quantile (e.g. missing SeaRegion).
    sc_deg = df['SC_deg'].to_numpy(dtype=float)
    gate = sea.map(by_sea['SC_deg'].quantile(sc_gate_q)).to_numpy(dtype=float)
    gate = np.where(np.isnan(gate), sc_deg, gate)
    df['SC_gate'] = sc_deg >= gate

    # Critical: region quantile of M, floored at ABS_BASELINE.
    m_vals = df['M'].to_numpy(dtype=float)
    tau = sea.map(by_sea['M'].quantile(high_pctl)).to_numpy(dtype=float)
    tau = np.maximum(np.where(np.isnan(tau), 0.0, tau), ABS_BASELINE)
    df['Critical'] = (m_vals > tau).astype(int)

    df['Y_2015'] = (df['SC_gate'].astype(int) & df['Critical'].astype(int)).astype(int)
    df[['MPA_ID','SeaRegion','province','M','SC_deg','SC_gate','Critical','Y_2015']].to_csv(out_path, index=False, encoding='utf-8-sig')

    total = int(len(df))
    ysum = int(df['Y_2015'].sum())
    return {"pctl": high_pctl, "scq": sc_gate_q, "year": 2015, "total": total, "Y_t": ysum, "rate": (ysum/total if total else np.nan)}

# ========================= 运行入口 =========================

def main():
    """Run the whole pipeline and write every output under ``OUT_DIR``.

    Steps, in order:
      1. ``process_one_year`` for each year in ``YEARS``; topology summary CSV.
      2. ``build_unified_and_fill_M`` on the collected results.
      3. Baseline panel (``make_panel``) and 2015 labels (``make_2015_labels``).
      4. Robustness panel via ``make_robust_export``.
      5. Sensitivity grid over ``SENS_PCTLS`` x ``SENS_SCQS`` and its summary CSV.
      6. A second, inline robustness export built on top of the baseline panel.

    Returns nothing; all results are written to disk and progress is printed.
    """
    warnings.filterwarnings('ignore')

    # Process each year in turn
    results = {}
    topo_rows = []
    for y in YEARS:
        r = process_one_year(y)
        results[y] = r
        topo_rows.append(r['topo_sp'])
        topo_rows.append(r['topo_in'])
        print(f"[{y}] public_n={len(r['public_ids'])}  wrote nodes/matrices/edgelists.")

    # Topology summary
    topo_df = pd.DataFrame(topo_rows)
    topo_df.to_csv(OUT_DIR/"summary"/"topology_summary.csv", index=False, encoding='utf-8-sig')

    # Pool the three periods, normalise, and write the unified grading
    # (per the original author's note; the helper is defined elsewhere in this file)
    build_unified_and_fill_M(results)

    # -- Baseline specification (P=0.85, SC=Q40) --
    main_panel_path = OUT_DIR/"alaam"/"alaam_nodes_panel.csv"
    stats_main_2020_2025 = make_panel(results, HIGH_PCTL_MAIN, SC_GATE_Q_MAIN, main_panel_path)
    print("Main panel saved:", main_panel_path)
    main_2015_path = OUT_DIR/"alaam"/f"y2015_labels__critical_p{int(HIGH_PCTL_MAIN*100)}_scq{int(SC_GATE_Q_MAIN*100)}.csv"
    stats_main_2015 = make_2015_labels(results, HIGH_PCTL_MAIN, SC_GATE_Q_MAIN, main_2015_path)
    print("2015 labels (main) saved:", main_2015_path)
    # -- Robustness panel (all RHS terms needed for the checks, written in one go) --
    # NOTE(review): the inline export at the end of main() writes to this same
    # path, so the file produced here is overwritten later in the run.
    robust_panel_path = OUT_DIR / "alaam" / "alaam_nodes_panel__robust.csv"
    _ = make_robust_export(results, HIGH_PCTL_MAIN, SC_GATE_Q_MAIN, robust_panel_path)

    # -- Sensitivity grid (PCTL x SC_Q) --
    stats_all = []
    # Start with the baseline
    # NOTE(review): the spec tag is hard-coded here; if the grid below also
    # contains the baseline combination, the summary gets both sets of rows
    # under the same tag — confirm against SENS_PCTLS / SENS_SCQS.
    stats_all.append(stats_main_2020_2025.assign(spec="critical_p85_scq40"))
    stats_all.append(pd.DataFrame([stats_main_2015]).assign(spec="critical_p85_scq40"))
    for pctl in SENS_PCTLS:
        for scq in SENS_SCQS:
            tag = f"critical_p{int(round(pctl*100))}_scq{int(round(scq*100))}"
            panel_path = OUT_DIR/"alaam"/f"alaam_nodes_panel__{tag}.csv"
            st_2020_2025 = make_panel(results, pctl, scq, panel_path)
            print("Sensitivity panel saved:", panel_path)
            y2015_path = OUT_DIR/"alaam"/f"y2015_labels__{tag}.csv"
            st_2015 = make_2015_labels(results, pctl, scq, y2015_path)
            print("2015 labels saved:", y2015_path)
            stats_all.append(st_2020_2025.assign(spec=tag))
            stats_all.append(pd.DataFrame([st_2015]).assign(spec=tag))

    # Summary table
    stats_df = pd.concat(stats_all, ignore_index=True)
    stats_df = stats_df[['spec','pctl','scq','year','total','Y_t','rate']].sort_values(['spec','year'])
    stats_df.to_csv(OUT_DIR/"summary"/"sensitivity_grid_summary.csv", index=False, encoding="utf-8-sig")
    print("Sensitivity summary saved:", OUT_DIR/"summary"/"sensitivity_grid_summary.csv")

    # ========================= ROBUST panel export (extra RHS terms) =========================
    # Requirement: must run after the baseline panel is saved and the sensitivity grid is done
    def _z_1d(x: pd.Series) -> pd.Series:
        """Z-score ``x`` after numeric coercion; divide by 1.0 when the std is not positive."""
        x = pd.to_numeric(x, errors="coerce")
        mu = x.mean()
        sd = x.std(ddof=0)
        return (x - mu) / (sd if (sd and sd > 0) else 1.0)

    def _safe_div(num: pd.Series, den: pd.Series) -> pd.Series:
        """Divide ``num`` by ``den``; zero denominators (and any NaN result) yield 0.0."""
        den = den.replace(0, np.nan)
        out = num / den
        return out.fillna(0.0)

    # Re-read the baseline panel just written and add the robustness RHS terms to it
    panel_main = pd.read_csv(main_panel_path)

    robust_rows = []
    for t in PANEL_TARGET_YEARS:
        t_prev = YEARS[YEARS.index(t) - 1]

        # --- Rows of year t from the panel (keeps them consistent with the main panel) ---
        dt = panel_main[panel_main["Year"] == t].copy()

        # --- Base inputs for year t-1 ---
        ids_prev = results[t_prev]['public_ids']
        # Degree (total) in the institutional binary network, per the original author's note:
        A_prev = results[t_prev]['A_bin']
        deg_prev = pd.Series(A_prev.sum(axis=1), index=ids_prev)

        # Spatial binary network (used for the spatial lag), per the original author's note:
        Bsp_prev = results[t_prev]['Bsp']
        # M values of the t-1 nodes (already pooled-normalised, per the original author's note)
        prev_nodes = results[t_prev]['nodes_df'].set_index('MPA_ID')
        M_prev = prev_nodes['M'].astype(float).reindex(ids_prev).fillna(0.0)

        # 1) spatial_lag_misfit_t1 (row mean over the binary adjacency; z-scored later)
        deg_sp_prev = Bsp_prev.sum(axis=1)
        # Mean neighbour M: (B * M) / degree
        sl_raw = (Bsp_prev @ M_prev.values)
        sl_avg = _safe_div(pd.Series(sl_raw, index=ids_prev), pd.Series(deg_sp_prev, index=ids_prev))
        sl_map = sl_avg.to_dict()
        dt['spatial_lag_misfit_t1'] = dt['MPA_ID'].map(sl_map).fillna(0.0)

        # 2) logdeg_t1_z (log(degree + 1), standardised within year)
        logdeg_prev = np.log1p(deg_prev)
        dt['logdeg_t1'] = dt['MPA_ID'].map(logdeg_prev).fillna(0.0)
        dt['logdeg_t1_z'] = dt.groupby('Year')['logdeg_t1'].transform(_z_1d)

        # 3) z_prov_count_looi_t1_z (province-level leave-one-out share of high misfit)
        #    First build the t-1 high_misfit flag with the same threshold rule
        #    (upper quartile of positive M, floored at 0.60)
        prev_df_all = results[t_prev]['nodes_df'].copy()
        prev_pos = prev_df_all['M'][prev_df_all['M'] > 0]
        Q3_prev = float(prev_pos.quantile(0.75)) if len(prev_pos) > 0 else 0.0
        tau_high = max(Q3_prev, 0.60)
        high_prev = (prev_df_all.set_index('MPA_ID')['M'] > tau_high).astype(int)
        prov_prev = prev_df_all.set_index('MPA_ID')['province'].astype(str)

        # Per-province total of high-misfit nodes and sample size
        grp = pd.DataFrame({
            'high': high_prev,
            'prov': prov_prev
        }).reset_index().groupby('prov', as_index=False).agg(
            total_high=('high', 'sum'),
            total_cnt =('high', 'count')
        ).set_index('prov')

        # LOOI per node = (province high-misfit total - own flag) / (province size - 1)
        looi_map = {}
        for i in ids_prev:
            prov_i = prov_prev.get(i, 'Unknown')
            if prov_i in grp.index:
                TH = grp.at[prov_i, 'total_high']
                TC = grp.at[prov_i, 'total_cnt']
                h_i = int(high_prev.get(i, 0))
                den = max(TC - 1, 1)  # guard against a zero denominator
                looi_map[i] = float(max(TH - h_i, 0) / den)
            else:
                looi_map[i] = 0.0

        dt['z_prov_count_looi_t1'] = dt['MPA_ID'].map(looi_map).fillna(0.0)
        dt['z_prov_count_looi_t1_z'] = dt.groupby('Year')['z_prov_count_looi_t1'].transform(_z_1d)

        # 4) nonNat_share_t1_z (share of non-national-level ties = deg_nonNat_{t-1} / degree_{t-1})
        deg_nonNat_prev = prev_nodes['deg_nonNat'].reindex(ids_prev).fillna(0.0)
        share_prev = _safe_div(deg_nonNat_prev, deg_prev.reindex(ids_prev).fillna(0.0))
        share_map = share_prev.to_dict()
        dt['nonNat_share_t1'] = dt['MPA_ID'].map(share_map).fillna(0.0)
        dt['nonNat_share_t1_z'] = dt.groupby('Year')['nonNat_share_t1'].transform(_z_1d)

        # 5) corridor_betweenness_t1_z: if the main panel has no z column, derive it from the raw one
        if 'corridor_betweenness_t1_z' not in dt.columns and 'corridor_betweenness_t1' in dt.columns:
            dt['corridor_betweenness_t1_z'] = dt.groupby('Year')['corridor_betweenness_t1'].transform(_z_1d)
        # Falls back to a constant 0.0 column when neither variant exists
        dt['corridor_betweenness_t1_z'] = dt.get('corridor_betweenness_t1_z', 0.0)

        # 6) spatial_lag_misfit_t1_z (the quantity from step 1, z-scored by year)
        dt['spatial_lag_misfit_t1_z'] = dt.groupby('Year')['spatial_lag_misfit_t1'].transform(_z_1d)

        # 7) two_path_open_t1_ortho_z (residual of two_path_open_t1 regressed on logdeg / corridor / spatial lag)
        y = pd.to_numeric(dt['two_path_open_t1'], errors='coerce').fillna(0.0).values.reshape(-1, 1)
        X_cols = ['logdeg_t1', 'corridor_betweenness_t1_z', 'spatial_lag_misfit_t1_z']
        X = np.column_stack([pd.to_numeric(dt[c], errors='coerce').fillna(0.0).values for c in X_cols])
        X = np.column_stack([np.ones(len(dt)), X])  # add the intercept column
        # Least-squares fit, then take the residual
        try:
            beta, *_ = np.linalg.lstsq(X, y, rcond=None)
            resid = (y - X @ beta).ravel()
        except Exception:
            resid = y.ravel()  # fallback: use the raw values unchanged
        dt['two_path_open_t1_ortho'] = resid
        dt['two_path_open_t1_ortho_z'] = dt.groupby('Year')['two_path_open_t1_ortho'].transform(_z_1d)

        # 8) two_path_topQ4_ortho_t1 (upper-quartile dummy based on the orthogonalised residual)
        q3 = np.nanquantile(dt['two_path_open_t1_ortho'].values, 0.75) if len(dt) else 0.0
        dt['two_path_topQ4_ortho_t1'] = (dt['two_path_open_t1_ortho'] > q3).astype(int)

        robust_rows.append(dt)

    panel_robust = pd.concat(robust_rows, ignore_index=True)

    # Keep the core robustness columns plus the necessary main-panel columns (main panel itself is untouched)
    keep_cols = [
        'Year','MPA_ID','Y_t','SeaRegion','province',
        'high_misfit_t1',
        'logdeg_t1','logdeg_t1_z',
        'corridor_betweenness_t1','corridor_betweenness_t1_z',
        'two_path_open_t1','two_path_open_t1_ortho','two_path_open_t1_ortho_z','two_path_topQ4_ortho_t1',
        'spatial_lag_misfit_t1','spatial_lag_misfit_t1_z',
        'z_prov_count_looi_t1','z_prov_count_looi_t1_z',
        'nonNat_share_t1','nonNat_share_t1_z',
        'SC_gate','covered_any_t','M'
    ]
    # Also carry over any extra main-panel columns (e.g. triangles_t1 / two_path_open_t1)
    # NOTE(review): panel_robust is concatenated from copies of panel_main rows,
    # so every panel_main column should already be present and this loop looks
    # like a no-op — confirm before relying on it.
    for c in panel_main.columns:
        if c not in keep_cols and c not in panel_robust.columns:
            panel_robust[c] = panel_main[c]

    # Column order: keep_cols first (those that exist), then everything else
    final_cols = [c for c in keep_cols if c in panel_robust.columns] + \
                 [c for c in panel_robust.columns if c not in keep_cols]
    panel_robust = panel_robust[final_cols]

    # NOTE(review): same file name as robust_panel_path above — this write
    # replaces the output of make_robust_export.
    robust_path = OUT_DIR/"alaam"/"alaam_nodes_panel__robust.csv"
    panel_robust.to_csv(robust_path, index=False, encoding="utf-8-sig")
    print("[ROBUST] robust panel saved:", robust_path)

# Call main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
